#!/usr/bin/env python3

from boututils.datafile import DataFile
import itertools
import time
import numpy as np
from boututils.run_wrapper import launch_safe, shell_safe

#requires: all_tests
#requires: netcdf


class timer(object):
    """Context manager that prints how long the enclosed block took

    On exit it prints the elapsed time in seconds followed by ``msg``.
    The report is printed whether or not the block raised, and
    exceptions are never suppressed.

    Parameters
    ----------
    msg : str
        Label printed after the elapsed time

    """
    def __init__(self, msg):
        self.msg = msg
        # Set on entry; None until the context has been entered
        self.start = None

    def __enter__(self):
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments during the run
        self.start = time.perf_counter()
        # Returning self makes "with timer(...) as t" usable
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        elapsed = time.perf_counter() - self.start
        print("{:12.8f}s {}".format(elapsed, self.msg))
        # Falsy return value: let any exception from the block propagate
        return False


def timed_shell_safe(cmd, *args, **kwargs):
    """Wraps shell_safe in a timer

    ``cmd`` is used both as the command and as the label printed with
    the elapsed time. All other arguments are forwarded unchanged.

    Returns whatever shell_safe returns, so the wrapper can be used
    as a drop-in replacement.

    """
    with timer(cmd):
        # The timer still reports on the way out of the "with" block
        return shell_safe(cmd, *args, **kwargs)


def timed_launch_safe(cmd, *args, **kwargs):
    """Wraps launch_safe in a timer

    ``cmd`` is used both as the command and as the label printed with
    the elapsed time. All other arguments (e.g. ``nproc``, ``mthread``)
    are forwarded unchanged.

    Returns whatever launch_safe returns, so the wrapper can be used
    as a drop-in replacement.

    """
    with timer(cmd):
        # The timer still reports on the way out of the "with" block
        return launch_safe(cmd, *args, **kwargs)


def _mismatch_report(a, b):
    """Lists the elements of two equally-shaped arrays that differ

    Uses the same criterion as the np.allclose check in
    _compare_datafiles (default tolerances, NaNs compare equal), so
    only the elements that actually caused the failure are reported.

    Returns a string with one "index: left != right" line per element.

    """
    a = np.asarray(a)
    b = np.asarray(b)
    differing = np.argwhere(~np.isclose(a, b, equal_nan=True))
    lines = []
    for row in differing:
        # Plain Python ints keep the printed index readable, e.g. (0, 1, 2)
        index = tuple(row.tolist())
        lines.append("{}: {} != {}\n".format(index, a[index], b[index]))
    return "".join(lines)


def _compare_datafiles(d1, d2):
    """Compares every variable of d1 against the same variable in d2

    Raises RuntimeError on the first variable whose shape or data differ.

    """
    # Variables compared by shape only, not by value
    skipped_names = ("MXSUB", "MYSUB", "NXPE", "NYPE", "iteration", "wall_time")
    skipped_prefixes = ("wtime", "ncalls")

    for v in d1.keys():
        # Look each variable up once instead of re-indexing the files
        # for every use
        a = d1[v]
        b = d2[v]
        if a.shape != b.shape:
            raise RuntimeError(
                "shape mismatch in {}: {} != {}".format(v, a.shape, b.shape))
        if v in skipped_names or v.startswith(skipped_prefixes):
            continue
        if not np.allclose(a, b, equal_nan=True):
            raise RuntimeError(
                "data mismatch in {}:\n{}".format(v, _mismatch_report(a, b)))


def verify(f1, f2):
    """Verifies that two BOUT++ files are identical

    Every variable found in ``f1`` must have the same shape in ``f2``
    and, unless it is one of the skipped variables, the same values
    according to np.allclose (default tolerances, NaNs compare equal).
    Variables that exist only in ``f2`` are not checked.

    Parameters
    ----------
    f1, f2 : str
        Names of the files to compare

    Raises
    ------
    RuntimeError
        If a variable differs in shape or in data

    """
    with timer("verify %s %s" % (f1, f2)):
        d1 = DataFile(f1)
        try:
            d2 = DataFile(f2)
            try:
                _compare_datafiles(d1, d2)
            finally:
                d2.close()
        finally:
            # Close even when the comparison raised or f2 failed to open
            d1.close()


timed_shell_safe("make")

# Reference data: a single run whose dump file is kept as f1.nc
timed_shell_safe("./squash -q -q -q nout=2")
timed_shell_safe("mv data/BOUT.dmp.0.nc f1.nc")

# The two bout-squashoutput invocations used below. They differ only in the
# extra "a" flag (presumably "append" -- confirm with bout-squashoutput --help)
SQUASH_OUTPUT = "../../../bin/bout-squashoutput -qdcl 9 data --outputname ../f2.nc"
SQUASH_OUTPUT_AGAIN = "../../../bin/bout-squashoutput -qdcal 9 data --outputname ../f2.nc"

# Keyword arguments passed along with each command
ON_FOUR_PROCS = {"nproc": 4, "mthread": 1}
NO_OPTIONS = {}

# Each scenario is a list of (runner, command, keyword arguments) steps that
# must leave behind an f2.nc matching the reference f1.nc
SCENARIOS = [
    # Parallel test
    [
        (timed_launch_safe, "./squash -q -q -q nout=2", ON_FOUR_PROCS),
        (timed_shell_safe, SQUASH_OUTPUT, NO_OPTIONS),
    ],
    # Parallel and in two pieces
    [
        (timed_launch_safe, "./squash -q -q -q", ON_FOUR_PROCS),
        (timed_shell_safe, SQUASH_OUTPUT, NO_OPTIONS),
        (timed_launch_safe, "./squash -q -q -q restart", ON_FOUR_PROCS),
        (timed_shell_safe, SQUASH_OUTPUT_AGAIN, NO_OPTIONS),
    ],
    # Parallel and in two pieces without dump_on_restart
    [
        (timed_launch_safe, "./squash -q -q -q", ON_FOUR_PROCS),
        (timed_shell_safe, SQUASH_OUTPUT, NO_OPTIONS),
        (timed_launch_safe, "./squash -q -q -q restart dump_on_restart=false",
         ON_FOUR_PROCS),
        (timed_shell_safe, SQUASH_OUTPUT_AGAIN, NO_OPTIONS),
    ],
    # Sequential test
    [
        (timed_shell_safe, "./squash -q -q -q nout=2", NO_OPTIONS),
        (timed_shell_safe, SQUASH_OUTPUT, NO_OPTIONS),
    ],
]

for steps in SCENARIOS:
    # Start every scenario without a leftover squashed file
    timed_shell_safe("rm -f f2.nc")
    for runner, command, options in steps:
        runner(command, **options)
    verify("f1.nc", "f2.nc")
